import time
from collections import defaultdict
from dataclasses import asdict, dataclass, field

from django.conf import settings
from django.db import transaction
from django.db.models import F
from django.utils.timezone import now as timezone_now
from django.utils.translation import gettext as _

from analytics.lib.counts import COUNT_STATS, do_increment_logging_stat
from zerver.lib.exceptions import JsonableError
from zerver.lib.message import (
    bulk_access_messages,
    format_unread_message_details,
    get_raw_unread_data,
)
from zerver.lib.queue import queue_event_on_commit
from zerver.lib.stream_subscription import get_subscribed_stream_recipient_ids_for_user
from zerver.lib.topic import filter_by_topic_name_via_message
from zerver.lib.user_message import DEFAULT_HISTORICAL_FLAGS, create_historical_user_messages
from zerver.models import Message, Recipient, UserMessage, UserProfile
from zerver.tornado.django_api import send_event_on_commit, send_event_rollback_unsafe


@dataclass
class ReadMessagesEvent:
    """Payload for the `update_message_flags` client event marking messages read.

    Serialized with dataclasses.asdict() before being passed to
    send_event_on_commit / send_event_rollback_unsafe.  The init=False
    fields are constants of this event type.
    """

    # Message ids being marked read; empty when `all` is True.
    messages: list[int]
    # True only for the mark-all-as-read case, where clients resync instead.
    all: bool
    type: str = field(default="update_message_flags", init=False)
    op: str = field(default="add", init=False)
    # NOTE(review): appears to duplicate `op` for compatibility with
    # clients expecting the older key name — confirm against API docs.
    operation: str = field(default="add", init=False)
    flag: str = field(default="read", init=False)


def do_mark_all_as_read(user_profile: UserProfile, *, timeout: float | None = None) -> int | None:
    """Mark every unread message of the user as read, in committed batches.

    Returns the number of rows updated, or None if `timeout` (seconds)
    elapsed before finishing; batches committed before the timeout stay
    marked as read.
    """
    start_time = time.monotonic()

    # First, we clear mobile push notifications.  This is safer in the
    # event that the below logic times out and we're killed.
    all_push_message_ids = list(
        UserMessage.objects.filter(
            user_profile=user_profile,
        )
        .extra(  # noqa: S610
            where=[UserMessage.where_active_push_notification()],
        )
        .values_list("message_id", flat=True)[0:10000]
    )
    do_clear_mobile_push_notifications_for_ids([user_profile.id], all_push_message_ids)

    batch_size = 2000
    count = 0
    while True:
        # Check the deadline between batches so we give up cleanly
        # rather than overrunning; prior batches are already committed.
        if timeout is not None and time.monotonic() >= start_time + timeout:
            return None

        # Each batch is its own durable transaction, so progress is
        # committed incrementally.
        with transaction.atomic(durable=True):
            query = (
                UserMessage.select_for_update_query()
                .filter(user_profile=user_profile)
                .extra(where=[UserMessage.where_unread()])[:batch_size]  # noqa: S610
            )
            # This updated_count is the same as the number of UserMessage
            # rows selected, because due to the FOR UPDATE lock, we're guaranteed
            # that all the selected rows will indeed be updated.
            # UPDATE queries don't support LIMIT, so we have to use a subquery
            # to do batching.
            updated_count = UserMessage.objects.filter(id__in=query).update(
                flags=F("flags").bitor(UserMessage.flags.read),
            )

            event_time = timezone_now()
            do_increment_logging_stat(
                user_profile,
                COUNT_STATS["messages_read::hour"],
                None,
                event_time,
                increment=updated_count,
            )
            do_increment_logging_stat(
                user_profile,
                COUNT_STATS["messages_read_interactions::hour"],
                None,
                event_time,
                # At most one "interaction" is counted per batch.
                increment=min(1, updated_count),
            )

            count += updated_count
            # A short batch means we've drained all unread rows.
            if updated_count < batch_size:
                break

    event = asdict(
        ReadMessagesEvent(
            messages=[],  # we don't send messages, since the client reloads anyway
            all=True,
        )
    )
    send_event_rollback_unsafe(user_profile.realm, event, [user_profile.id])

    return count


@transaction.atomic(durable=True)
def do_mark_stream_messages_as_read(
    user_profile: UserProfile, stream_recipient_id: int, topic_name: str | None = None
) -> int:
    """Mark the user's unread messages in a channel (optionally one topic) as read.

    Runs in a single durable transaction, locking the affected
    UserMessage rows (FOR UPDATE).  Returns the number of rows updated.
    """
    query = (
        UserMessage.select_for_update_query()
        .filter(
            user_profile=user_profile,
            message__recipient_id=stream_recipient_id,
        )
        .extra(  # noqa: S610
            where=[UserMessage.where_unread()],
        )
    )

    if topic_name:
        query = filter_by_topic_name_via_message(
            query=query,
            topic_name=topic_name,
        )

    # Materialize the ids before the UPDATE clears the unread flag; the
    # same queryset would match nothing afterwards.
    message_ids = list(query.values_list("message_id", flat=True))

    if len(message_ids) == 0:
        return 0

    count = query.update(
        flags=F("flags").bitor(UserMessage.flags.read),
    )

    event = asdict(
        ReadMessagesEvent(
            messages=message_ids,
            all=False,
        )
    )
    event_time = timezone_now()

    send_event_on_commit(user_profile.realm, event, [user_profile.id])
    do_clear_mobile_push_notifications_for_ids([user_profile.id], message_ids)

    do_increment_logging_stat(
        user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
    )
    do_increment_logging_stat(
        user_profile,
        COUNT_STATS["messages_read_interactions::hour"],
        None,
        event_time,
        # A whole bulk operation counts as a single interaction.
        increment=min(1, count),
    )
    return count


@transaction.atomic(savepoint=False)
def do_mark_muted_user_messages_as_read(
    user_profile: UserProfile,
    muted_user: UserProfile,
) -> int:
    """Mark all of user_profile's unread messages sent by muted_user as read.

    Wrapped in transaction.atomic(savepoint=False); locks the affected
    UserMessage rows (FOR UPDATE) and returns the number of rows updated.
    """
    unread_rows = (
        UserMessage.select_for_update_query()
        .filter(user_profile=user_profile, message__sender=muted_user)
        .extra(where=[UserMessage.where_unread()])  # noqa: S610
    )
    # Materialize the ids now: once the UPDATE clears the unread bit,
    # this queryset would no longer match the same rows.
    unread_message_ids = list(unread_rows.values_list("message_id", flat=True))

    if not unread_message_ids:
        return 0

    updated = unread_rows.update(
        flags=F("flags").bitor(UserMessage.flags.read),
    )
    event_time = timezone_now()

    read_event = asdict(
        ReadMessagesEvent(
            messages=unread_message_ids,
            all=False,
        )
    )
    send_event_on_commit(user_profile.realm, read_event, [user_profile.id])
    do_clear_mobile_push_notifications_for_ids([user_profile.id], unread_message_ids)

    do_increment_logging_stat(
        user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=updated
    )
    do_increment_logging_stat(
        user_profile,
        COUNT_STATS["messages_read_interactions::hour"],
        None,
        event_time,
        # The whole bulk mark-as-read counts as one interaction.
        increment=min(1, updated),
    )
    return updated


def do_update_mobile_push_notification(
    message: Message,
    prior_mention_user_ids: set[int],
    mentions_user_ids: set[int],
    stream_push_user_ids: set[int],
) -> None:
    """Revoke mobile push notifications made stale by a message edit.

    Called during the message edit code path: users who were mentioned
    before the edit, are no longer mentioned after it, and don't get
    channel-wide push notifications should have their notification for
    this message removed.  See #15428 for details.

    A perfect implementation would also support updating the message in
    a sent notification if a message was edited to mention a group
    rather than a user (or vice versa), though it is likely not worth
    the effort to do such a change.
    """
    # Only channel messages are handled here.
    if not message.is_channel_message:
        return

    no_longer_notified = prior_mention_user_ids - mentions_user_ids - stream_push_user_ids
    do_clear_mobile_push_notifications_for_ids(list(no_longer_notified), [message.id])


def do_clear_mobile_push_notifications_for_ids(
    user_profile_ids: list[int] | None, message_ids: list[int]
) -> None:
    """Queue "remove" mobile-notification events for active push notifications.

    With user_profile_ids given (mark-as-read or mention-removal paths),
    only those users' notifications are cleared; with None (message
    deletion), notifications are cleared for every recipient of the
    messages.
    """
    if not message_ids:
        return

    if user_profile_ids is None:
        # Message-deletion path: no user restriction.
        # Uses index: zerver_usermessage_message_active_mobile_push_notification_idx
        candidates = UserMessage.objects.filter(
            message_id__in=message_ids,
        )
    else:
        # This branch covers messages marked as read by a user, and
        # messages edited to remove mention(s).
        if not user_profile_ids:
            return

        # Multiple users are only supported for the message-edit use
        # case, where there is a single message_id.
        assert len(user_profile_ids) == 1 or len(message_ids) == 1

        candidates = UserMessage.objects.filter(
            message_id__in=message_ids,
            user_profile_id__in=user_profile_ids,
        )

    notifications_to_update = candidates.extra(  # noqa: S610
        where=[UserMessage.where_active_push_notification()],
    ).values_list("user_profile_id", "message_id")

    # Group the affected message ids per user so each user gets one event.
    per_user: dict[int, list[int]] = defaultdict(list)
    for user_id, message_id in notifications_to_update:
        per_user[user_id].append(message_id)

    for user_id, ids_to_remove in per_user.items():
        notice = {
            "type": "remove",
            "user_profile_id": user_id,
            "message_ids": ids_to_remove,
        }
        if settings.MOBILE_NOTIFICATIONS_SHARDS > 1:  # nocoverage
            shard_id = user_id % settings.MOBILE_NOTIFICATIONS_SHARDS + 1
            queue_event_on_commit(f"missedmessage_mobile_notifications_shard{shard_id}", notice)
        else:
            queue_event_on_commit("missedmessage_mobile_notifications", notice)


def do_update_message_flags(
    user_profile: UserProfile, operation: str, flag: str, messages: list[int]
) -> tuple[int, list[int]]:
    """Add or remove a personal message flag on the given message ids.

    Returns (count of messages whose flag actually changed, list of
    channel ids that were ignored because the user is not subscribed;
    only populated when removing the "read" flag).

    Raises JsonableError for an invalid or non-editable flag, an invalid
    operation, or message ids the user cannot access.
    """
    valid_flags = [item for item in UserMessage.flags if item not in UserMessage.NON_API_FLAGS]
    if flag not in valid_flags:
        raise JsonableError(_("Invalid flag: '{flag}'").format(flag=flag))
    if flag in UserMessage.NON_EDITABLE_FLAGS:
        raise JsonableError(_("Flag not editable: '{flag}'").format(flag=flag))
    if operation not in ("add", "remove"):
        raise JsonableError(
            _("Invalid message flag operation: '{operation}'").format(operation=operation)
        )
    is_adding = operation == "add"
    flagattr = getattr(UserMessage.flags, flag)
    # The flag bit value a row should end up with after this operation.
    flag_target = flagattr if is_adding else 0

    ignored_because_not_subscribed_channels: list[int] = []
    with transaction.atomic(durable=True):
        if flag == "read" and not is_adding:
            # We have an invariant that all stream messages marked as
            # unread must be in streams the user is subscribed to.
            #
            # When marking as unread, we enforce this invariant by
            # ignoring any messages in streams the user is not
            # currently subscribed to.
            subscribed_recipient_ids = get_subscribed_stream_recipient_ids_for_user(user_profile)

            messages_in_unsubscribed_streams = set(
                # Uses index: zerver_message_pkey
                Message.objects.select_related("recipient")
                .filter(id__in=messages, recipient__type=Recipient.STREAM)
                .exclude(recipient_id__in=subscribed_recipient_ids)
                .values_list("id", "recipient__type_id")
            )

            message_ids_in_unsubscribed_streams = {
                message[0] for message in messages_in_unsubscribed_streams
            }

            # recipient__type_id for a STREAM recipient is the channel id.
            ignored_because_not_subscribed_channels = list(
                {message[1] for message in messages_in_unsubscribed_streams}
            )

            messages = [
                message_id
                for message_id in messages
                if message_id not in message_ids_in_unsubscribed_streams
            ]

        # Lock (FOR UPDATE) the user's existing UserMessage rows for the
        # requested messages, keyed by message id.
        ums = {
            um.message_id: um
            for um in UserMessage.select_for_update_query().filter(
                user_profile=user_profile, message_id__in=messages
            )
        }

        # Filter out rows that already have the desired flag.  We do
        # this here, rather than in the original database query,
        # because not all flags have database indexes and we want to
        # bound the cost of this operation.
        messages = [
            message_id
            for message_id in messages
            if (int(ums[message_id].flags) if message_id in ums else DEFAULT_HISTORICAL_FLAGS)
            & flagattr
            != flag_target
        ]
        count = len(messages)

        if DEFAULT_HISTORICAL_FLAGS & flagattr != flag_target:
            # When marking messages as read, creating "historical"
            # UserMessage rows would be a waste of storage, because
            # `flags.read | flags.historical` is exactly the flags we
            # simulate when processing a message for which a user has
            # access but no UserMessage row.
            #
            # Users can mutate flags for messages that don't have a
            # UserMessage yet.  Validate that the user is even allowed
            # to access these message_ids; if so, we will create
            # "historical" UserMessage rows for the messages in question.
            #
            # See create_historical_user_messages for a more detailed
            # explanation.
            historical_message_ids = set(messages) - set(ums.keys())
            historical_messages = bulk_access_messages(
                user_profile,
                list(
                    # Uses index: zerver_message_pkey
                    Message.objects.filter(id__in=historical_message_ids).prefetch_related(
                        "recipient"
                    )
                ),
                is_modifying_message=False,
            )
            # Any inaccessible id makes the whole request invalid.
            if len(historical_messages) != len(historical_message_ids):
                raise JsonableError(_("Invalid message(s)"))

            create_historical_user_messages(
                user_id=user_profile.id,
                message_ids=list(historical_message_ids),
                flagattr=flagattr,
                flag_target=flag_target,
            )

        # Update only the rows that already existed; historical rows were
        # created above with the target flag state.
        to_update = UserMessage.objects.filter(
            user_profile=user_profile, message_id__in=set(messages) & set(ums.keys())
        )
        if is_adding:
            to_update.update(flags=F("flags").bitor(flagattr))
        else:
            to_update.update(flags=F("flags").bitand(~flagattr))

        event = {
            "type": "update_message_flags",
            "op": operation,
            "operation": operation,
            "flag": flag,
            "messages": messages,
            "all": False,
        }

        if flag == "read" and not is_adding:
            # When removing the read flag (i.e. marking messages as
            # unread), extend the event with an additional object with
            # details on the messages required to update the client's
            # `unread_msgs` data structure.
            raw_unread_data = get_raw_unread_data(user_profile, messages)
            event["message_details"] = format_unread_message_details(
                user_profile.id, raw_unread_data
            )

        send_event_on_commit(user_profile.realm, event, [user_profile.id])

        if flag == "read" and is_adding:
            event_time = timezone_now()
            do_clear_mobile_push_notifications_for_ids([user_profile.id], messages)

            do_increment_logging_stat(
                user_profile, COUNT_STATS["messages_read::hour"], None, event_time, increment=count
            )
            do_increment_logging_stat(
                user_profile,
                COUNT_STATS["messages_read_interactions::hour"],
                None,
                event_time,
                # One API call counts as a single interaction.
                increment=min(1, count),
            )

    return (count, ignored_because_not_subscribed_channels)
